/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
 */
#ifndef LUT16DEP_HELPER_H
#define LUT16DEP_HELPER_H

using std::vector;

vector<uint8_t> CreatePackedDataset(const DenseDataset<uint8_t> &hashed_database)
{
    vector<uint8_t> packed_dataset;
    if (hashed_database.empty()) {
        return packed_dataset;
    }

    DimensionIndex real_num_blocks = hashed_database[0].nonzero_entries();
    DimensionIndex num_blocks = (hashed_database[0].nonzero_entries() + 1) & (~1);  // make sure is even
    packed_dataset.resize(num_blocks * ((hashed_database.size() + 31) & (~31)) / 2); // divisible by 31+1=32
    DatapointIndex k = 0;
    const size_t kUnrollBy = 32; // every 32 datapoints as a unit
    const size_t kPackedTo = 16; // kUnrollBy / 2 = 16 codes per dim
    const size_t kDims = 2; // mix 2 dims together, dim & dim + add
    const size_t kLeftShift = 16; // 2 ^ 4 = 16, left shift 4 bits
    const uint8_t perm0[16] = {0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30}; // kPackedTo
    const uint8_t add = 1;
    const size_t num_dp = hashed_database.size();
    
    for (; k < hashed_database.size() / kUnrollBy; ++k) {
        size_t start = k * kPackedTo * num_blocks; // kUnrollBy / 2 = 16 codes per dim
        for (size_t j = 0; j < num_blocks / kDims; ++j) { // combines every neighbouring 2 dims (blocks)
            for (size_t m = 0; m < kPackedTo; m++) {
                uint8_t u0 = hashed_database[k * kUnrollBy + perm0[m]].values()[kDims * j];
                uint8_t u1 = hashed_database[k * kUnrollBy + perm0[m] + add].values()[kDims * j];
                uint8_t u2 = (real_num_blocks > kDims * j + 1 ?
                    hashed_database[k * kUnrollBy + perm0[m]].values()[kDims * j + 1] : 0);
                uint8_t u3 = (real_num_blocks > kDims * j + 1 ?
                    hashed_database[k * kUnrollBy + perm0[m] + add].values()[kDims * j + 1] : 0);
                
                packed_dataset[start + j * kPackedTo * kDims + kDims * m] = u1 * kLeftShift + u0;
                packed_dataset[start + j * kPackedTo * kDims + kDims * m + 1] = u3 * kLeftShift + u2;
            }
        }
    }
    
    if (k * kUnrollBy < hashed_database.size()) {
        size_t start = k * kPackedTo * num_blocks;
        for (size_t j = 0; j < num_blocks / kDims; ++j) {
            auto checkOOB = [&](DimensionIndex dp_idx, size_t block_id) { // check if out of bounds
                return (dp_idx >= num_dp || block_id >= real_num_blocks);
            };
            for (size_t m = 0; m < kPackedTo; m++) {
                DatapointIndex dp_idx = k * kUnrollBy + perm0[m];
                uint8_t u0 = checkOOB(dp_idx, kDims * j) ? 0 : hashed_database[dp_idx].values()[kDims * j];
                uint8_t u2 = checkOOB(dp_idx, kDims * j + 1) ? 0 : hashed_database[dp_idx].values()[kDims * j + 1];

                dp_idx = k * kUnrollBy + perm0[m] + add;
                uint8_t u1 = checkOOB(dp_idx, kDims * j) ? 0 :hashed_database[dp_idx].values()[kDims * j];
                uint8_t u3 = checkOOB(dp_idx, kDims * j + 1) ? 0 : hashed_database[dp_idx].values()[kDims * j + 1];

                packed_dataset[start + j * kPackedTo * kDims + kDims * m] = u1 * kLeftShift + u0;
                packed_dataset[start + j * kPackedTo * kDims + kDims * m + 1] = u3 * kLeftShift + u2;
            }
        }
    }

    return packed_dataset;
}

#endif